{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## VAE MNIST example: BO in a latent space"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In this tutorial, we use the MNIST dataset and some standard PyTorch examples to show a synthetic problem where the input to the objective function is a `28 x 28` image. The main idea is to train a [variational auto-encoder (VAE)](https://arxiv.org/abs/1312.6114) on the MNIST dataset and run Bayesian Optimization in the latent space. We also refer readers to [this tutorial](http://krasserm.github.io/2018/04/07/latent-space-optimization/), which discusses [the method](https://arxiv.org/abs/1610.02415) of jointly training a VAE with a predictor (e.g., classifier), and shows a similar tutorial for the MNIST setting."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Install dependencies if we are running in colab\n",
        "import sys\n",
        "if 'google.colab' in sys.modules:\n",
        "    %pip install botorch"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "import torch\n",
        "\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "from torchvision import datasets  # transforms\n",
        "\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "dtype = torch.double\n",
        "SMOKE_TEST = os.environ.get(\"SMOKE_TEST\", False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Problem setup\n",
        "\n",
        "Let's first define our synthetic expensive-to-evaluate objective function. We assume that it takes the following form:\n",
        "\n",
        "$$\\text{image} \\longrightarrow \\text{image classifier} \\longrightarrow \\text{scoring function} \n",
        "\\longrightarrow \\text{score}.$$\n",
        "\n",
        "The classifier is a convolutional neural network (CNN) trained using the architecture of the [PyTorch CNN example](https://github.com/pytorch/examples/tree/master/mnist)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {},
      "outputs": [],
      "source": [
        "class Net(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(Net, self).__init__()\n",
        "        self.conv1 = nn.Conv2d(1, 20, 5, 1)\n",
        "        self.conv2 = nn.Conv2d(20, 50, 5, 1)\n",
        "        self.fc1 = nn.Linear(4 * 4 * 50, 500)\n",
        "        self.fc2 = nn.Linear(500, 10)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = F.relu(self.conv1(x))\n",
        "        x = F.max_pool2d(x, 2, 2)\n",
        "        x = F.relu(self.conv2(x))\n",
        "        x = F.max_pool2d(x, 2, 2)\n",
        "        x = x.view(-1, 4 * 4 * 50)\n",
        "        x = F.relu(self.fc1(x))\n",
        "        x = self.fc2(x)\n",
        "        return F.log_softmax(x, dim=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {},
      "outputs": [],
      "source": [
        "def get_pretrained_dir() -> str:\n",
        "    \"\"\"\n",
        "    Get the directory of pretrained models, which are in the BoTorch repo.\n",
        "\n",
        "    Returns the location specified by PRETRAINED_LOCATION if that env\n",
        "    var is set; otherwise checks if we are in a likely part of the BoTorch\n",
        "    repo (botorch/botorch or botorch/tutorials) and returns the right path.\n",
        "    \"\"\"\n",
        "    if \"PRETRAINED_LOCATION\" in os.environ.keys():\n",
        "        return os.environ[\"PRETRAINED_LOCATION\"]\n",
        "    cwd = os.getcwd()\n",
        "    folder = os.path.basename(cwd)\n",
        "    # automated tests run from botorch folder\n",
        "    if folder == \"botorch\":\n",
        "        return os.path.join(cwd, \"tutorials/pretrained_models/\")\n",
        "    # typical case (running from tutorial folder)\n",
        "    elif folder == \"tutorials\":\n",
        "        return os.path.join(cwd, \"pretrained_models/\")\n",
        "    raise FileNotFoundError(\"Could not figure out location of pretrained models.\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {},
      "outputs": [],
      "source": [
        "cnn_weights_path = os.path.join(get_pretrained_dir(), \"mnist_cnn.pt\")\n",
        "cnn_model = Net().to(dtype=dtype, device=device)\n",
        "cnn_state_dict = torch.load(cnn_weights_path, map_location=device, weights_only=True)\n",
        "cnn_model.load_state_dict(cnn_state_dict);"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Our VAE model follows the [PyTorch VAE example](https://github.com/pytorch/examples/tree/master/vae), except that we use the same data transform from the CNN tutorial for consistency. We then instantiate the model and again load a pre-trained model. To train these models, we refer readers to the PyTorch Github repository. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {},
      "outputs": [],
      "source": [
        "class VAE(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.fc1 = nn.Linear(784, 400)\n",
        "        self.fc21 = nn.Linear(400, 20)\n",
        "        self.fc22 = nn.Linear(400, 20)\n",
        "        self.fc3 = nn.Linear(20, 400)\n",
        "        self.fc4 = nn.Linear(400, 784)\n",
        "\n",
        "    def encode(self, x):\n",
        "        h1 = F.relu(self.fc1(x))\n",
        "        return self.fc21(h1), self.fc22(h1)\n",
        "\n",
        "    def reparameterize(self, mu, logvar):\n",
        "        std = torch.exp(0.5 * logvar)\n",
        "        eps = torch.randn_like(std)\n",
        "        return mu + eps * std\n",
        "\n",
        "    def decode(self, z):\n",
        "        h3 = F.relu(self.fc3(z))\n",
        "        return torch.sigmoid(self.fc4(h3))\n",
        "\n",
        "    def forward(self, x):\n",
        "        mu, logvar = self.encode(x.view(-1, 784))\n",
        "        z = self.reparameterize(mu, logvar)\n",
        "        return self.decode(z), mu, logvar\n",
        "\n",
        "vae_weights_path = os.path.join(get_pretrained_dir(), \"mnist_vae.pt\")\n",
        "vae_model = VAE().to(dtype=dtype, device=device)\n",
        "vae_state_dict = torch.load(vae_weights_path, map_location=device, weights_only=True)\n",
        "vae_model.load_state_dict(vae_state_dict);"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We now define the scoring function that maps digits to scores. The function below prefers the digit '3'."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {},
      "outputs": [],
      "source": [
        "def score(y):\n",
        "    \"\"\"Returns a 'score' for each digit from 0 to 9. It is modeled as a squared exponential\n",
        "    centered at the digit '3'.\n",
        "    \"\"\"\n",
        "    return torch.exp(-2 * (y - 3) ** 2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Given the scoring function, we can now write our overall objective, which as discussed above, starts with an image and outputs a score. Let's say the objective computes the expected score given the probabilities from the classifier."
      ]
    },
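    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sketch of what the next cell computes (the symbols $p_k$, $s$, and $f$ are notation introduced here for illustration, not names from the code): if $p_k(x)$ denotes the classifier's probability that image $x$ shows digit $k$, and $s(k) = \\exp(-2 (k - 3)^2)$ is the scoring function above, then the objective is the expected score\n",
        "\n",
        "$$f(x) = \\sum_{k=0}^{9} p_k(x) \\, s(k) = \\sum_{k=0}^{9} p_k(x) \\exp\\left(-2 (k - 3)^2\\right).$$\n",
        "\n",
        "Since $s(3) = 1$, $s(2) = s(4) = e^{-2} \\approx 0.135$, and $s(k) \\le e^{-8}$ for every other digit, $f(x)$ approaches its maximum of 1 only when the classifier is confident that the image is a '3'."
      ]
    },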
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {},
      "outputs": [],
      "source": [
        "def score_image(x):\n",
        "    \"\"\"The input x is an image and an expected score\n",
        "    based on the CNN classifier and the scoring\n",
        "    function is returned.\n",
        "    \"\"\"\n",
        "    with torch.no_grad():\n",
        "        probs = torch.exp(cnn_model(x))  # b x 10\n",
        "        scores = score(\n",
        "            torch.arange(10, device=device, dtype=dtype)\n",
        "        ).expand(probs.shape)\n",
        "    return (probs * scores).sum(dim=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Finally, we define a helper function `decode` that takes as input the parameters `mu` and `logvar` of the variational distribution and performs reparameterization and the decoding. We use batched Bayesian optimization to search over the parameters `mu` and `logvar`"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {},
      "outputs": [],
      "source": [
        "def decode(train_x):\n",
        "    with torch.no_grad():\n",
        "        decoded = vae_model.decode(train_x)\n",
        "    return decoded.view(train_x.shape[0], 1, 28, 28)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "#### Model initialization and initial random batch\n",
        "\n",
        "We use a `SingleTaskGP` to model the score of an image generated by a latent representation. The model is initialized with points drawn from $[-6, 6]^{20}$."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {},
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "<stdin>:1:10: fatal error: 'omp.h' file not found\n",
            "#include <omp.h>\n",
            "         ^~~~~~~\n",
            "1 error generated.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[KeOps] Warning : omp.h header is not in the path, disabling OpenMP.\n",
            "[KeOps] Warning : Cuda libraries were not detected on the system ; using cpu only mode\n"
          ]
        }
      ],
      "source": [
        "from botorch.models import SingleTaskGP\n",
        "from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood\n",
        "from botorch.utils.transforms import normalize, unnormalize\n",
        "from botorch.models.transforms import Normalize\n",
        "\n",
        "d = 20\n",
        "bounds = torch.tensor([[-6.0] * d, [6.0] * d], device=device, dtype=dtype)\n",
        "\n",
        "\n",
        "def gen_initial_data(n=5):\n",
        "    # generate training data\n",
        "    train_x = unnormalize(\n",
        "        torch.rand(n, d, device=device, dtype=dtype),\n",
        "        bounds=bounds\n",
        "    )\n",
        "    train_obj = score_image(decode(train_x)).unsqueeze(-1)\n",
        "    best_observed_value = train_obj.max().item()\n",
        "    return train_x, train_obj, best_observed_value\n",
        "\n",
        "\n",
        "def get_fitted_model(train_x, train_obj, state_dict=None):\n",
        "    # initialize and fit model\n",
        "    model = SingleTaskGP(\n",
        "        train_X=normalize(train_x, bounds),\n",
        "        train_Y=train_obj,\n",
        "    )\n",
        "    if state_dict is not None:\n",
        "        model.load_state_dict(state_dict)\n",
        "    mll = ExactMarginalLogLikelihood(model.likelihood, model)\n",
        "    mll.to(train_x)\n",
        "    fit_gpytorch_mll(mll)\n",
        "    return model"
      ]
    },
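    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In the cell above, the GP is fit on inputs rescaled to the unit cube, and the acquisition function is likewise optimized over $[0, 1]^{20}$. As a sketch of the mapping (the symbols $\\tilde{z}$, $l$, and $u$ are notation introduced here for illustration): with lower bound $l = -6$ and upper bound $u = 6$ in every dimension, `unnormalize` maps a unit-cube point $\\tilde{z}$ to the latent point\n",
        "\n",
        "$$z = l + (u - l)\\,\\tilde{z} = 12\\,\\tilde{z} - 6,$$\n",
        "\n",
        "and `normalize` applies the inverse, $\\tilde{z} = (z - l) / (u - l) = (z + 6) / 12$. For example, $\\tilde{z} = 0.5$ corresponds to $z = 0$, the mean of the standard normal VAE prior."
      ]
    },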
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "#### Define a helper function that performs the essential BO step\n",
        "The helper function below takes an acquisition function as an argument, optimizes it, and returns the batch $\\{x_1, x_2, \\ldots x_q\\}$ along with the observed function values. For this example, we'll use a small batch of $q=3$."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {},
      "outputs": [],
      "source": [
        "from botorch.optim import optimize_acqf\n",
        "\n",
        "\n",
        "BATCH_SIZE = 3 if not SMOKE_TEST else 2\n",
        "NUM_RESTARTS = 10 if not SMOKE_TEST else 2\n",
        "RAW_SAMPLES = 256 if not SMOKE_TEST else 4\n",
        "\n",
        "\n",
        "def optimize_acqf_and_get_observation(acq_func):\n",
        "    \"\"\"Optimizes the acquisition function, and returns a\n",
        "    new candidate and a noisy observation\"\"\"\n",
        "\n",
        "    # optimize\n",
        "    candidates, _ = optimize_acqf(\n",
        "        acq_function=acq_func,\n",
        "        bounds=torch.stack(\n",
        "            [\n",
        "                torch.zeros(d, dtype=dtype, device=device),\n",
        "                torch.ones(d, dtype=dtype, device=device),\n",
        "            ]\n",
        "        ),\n",
        "        q=BATCH_SIZE,\n",
        "        num_restarts=NUM_RESTARTS,\n",
        "        raw_samples=RAW_SAMPLES,\n",
        "    )\n",
        "\n",
        "    # observe new values\n",
        "    new_x = unnormalize(candidates.detach(), bounds=bounds)\n",
        "    new_obj = score_image(decode(new_x)).unsqueeze(-1)\n",
        "    return new_x, new_obj"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Perform Bayesian Optimization loop with qEI\n",
        "The Bayesian optimization \"loop\" for a batch size of $q$ simply iterates the following steps: (1) given a surrogate model, choose a batch of points $\\{x_1, x_2, \\ldots x_q\\}$, (2) observe $f(x)$ for each $x$ in the batch, and (3) update the surrogate model. We run `N_BATCH=75` iterations. The acquisition function is approximated using `MC_SAMPLES=2048` samples. We also initialize the model with 5 randomly drawn points."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {},
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "[W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware.\n"
          ]
        }
      ],
      "source": [
        "from botorch import fit_gpytorch_mll\n",
        "from botorch.acquisition.monte_carlo import qExpectedImprovement\n",
        "from botorch.sampling.normal import SobolQMCNormalSampler\n",
        "\n",
        "seed = 1\n",
        "torch.manual_seed(seed)\n",
        "\n",
        "N_BATCH = 25 if not SMOKE_TEST else 3\n",
        "best_observed = []\n",
        "\n",
        "# call helper function to initialize model\n",
        "train_x, train_obj, best_value = gen_initial_data(n=5)\n",
        "best_observed.append(best_value)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We are now ready to run the BO loop (this make take a few minutes, depending on your machine)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "Running BO ........................."
          ]
        }
      ],
      "source": [
        "import warnings\n",
        "from matplotlib import pyplot as plt\n",
        "\n",
        "warnings.filterwarnings(\"ignore\")\n",
        "\n",
        "\n",
        "print(f\"\\nRunning BO \", end=\"\")\n",
        "\n",
        "state_dict = None\n",
        "# run N_BATCH rounds of BayesOpt after the initial random batch\n",
        "for iteration in range(N_BATCH):\n",
        "\n",
        "    # fit the model\n",
        "    model = get_fitted_model(\n",
        "        train_x=train_x,\n",
        "        train_obj=train_obj,\n",
        "        state_dict=state_dict,\n",
        "    )\n",
        "\n",
        "    # define the qNEI acquisition function\n",
        "    qEI = qExpectedImprovement(\n",
        "        model=model, best_f=train_obj.max()\n",
        "    )\n",
        "\n",
        "    # optimize and get new observation\n",
        "    new_x, new_obj = optimize_acqf_and_get_observation(qEI)\n",
        "\n",
        "    # update training points\n",
        "    train_x = torch.cat((train_x, new_x))\n",
        "    train_obj = torch.cat((train_obj, new_obj))\n",
        "\n",
        "    # update progress\n",
        "    best_value = train_obj.max().item()\n",
        "    best_observed.append(best_value)\n",
        "\n",
        "    state_dict = model.state_dict()\n",
        "\n",
        "    print(\".\", end=\"\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "EI recommends the best point observed so far. We can visualize what the images corresponding to recommended points *would have* been if the BO process ended at various times. Here, we show the progress of the algorithm by examining the images at 0%, 10%, 25%, 50%, 75%, and 100% completion. The first image is the best image found through the initial random batch."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAABGwAAADJCAYAAAB2bqQSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZD0lEQVR4nO3dXWzddf3A8U87WDf30DEeWipb3IUJJpih6zYLiCgNC0bChAuIF4IPPNmRwC6MU4GEkEzBCAGHJCIQL2CGC0bEBGMGG0L2QCcEELNgQmQKLZK4rgz2YPv7XyyUfz1nc6c9D99zvq9Xci727enp91veHvTj2e/XVhRFEQAAAAAko73RGwAAAABgMgMbAAAAgMQY2AAAAAAkxsAGAAAAIDEGNgAAAACJMbABAAAASIyBDQAAAEBiDGwAAAAAEmNgAwAAAJAYAxsAAACAxJxQqxfesGFD3HXXXTE0NBRLly6N++67L1asWPE/v298fDzefvvtmDdvXrS1tdVqexBFUcTo6Gj09PREe/vUZpc6J3U6Jwc6Jwc6Jwc6JwcVdV7UwMaNG4uZM2cWDz30UPGXv/yluOaaa4oFCxYUw8PD//N79+zZU0SEh0fdHnv27NG5R8s/dO6Rw0PnHjk8dO6Rw0PnHjk8jqfztqIoiqiylStXxvLly+MXv/hFRByZVi5atChuvPHG+MEPfnDM7x0ZGYkFCxbEZz/72ZgxY0a1twYTxsbG4tVXX429e/dGZ2dnxd+vc5qBzsmBzsmBzsmBzslBJZ1X/a9EHTp0KHbt2hXr1q2bWGtvb4/+/v7Ytm1byfMPHjwYBw8enPjz6OhoRETMmDHDf1Coi6l85FHnNBudkwOdkwOdkwOdk4Pj6bzqFx1+7733YmxsLLq6uiatd3V1xdDQUMnz169fH52dnROPRYsWVXtLUHU6Jwc6Jwc6Jwc6Jwc6pxU1/C5R69ati5GRkYnHnj17Gr0lqDqdkwOdkwOdkwOdkwOd0wyq/leiTjnllJgxY0YMDw9PWh8eHo7u7u6S53d0dERHR0e1twE1pXNyoHNyoHNyoHNyoHNaUdU/YTNz5sxYtmxZbN68eWJtfHw8Nm/eHH19fdX+cdAQOicHOicHOicHOicHOqcVVf0TNhERa9eujauuuip6e3tjxYoVcc8998T+/fvjW9/6Vi1+HDSEzsmBzsmBzsmBzsmBzmk1NRnYXHHFFfGvf/0rbr311hgaGoqzzz47nn766ZILQEEz0zk50Dk50Dk50Dk50Dmtpq0oiqLRm/j/9u3bF52dnXH22We7nRo1NTY2Fi+//HKMjIzE/Pnz6/qzdU696Jwc6Jwc6Jwc6JwcVNJ5w+8SBQAAAMBkBjYAAAAAiTGwAQAAAEiMgQ0AAABAYgxsAAAAABJjYAMAAACQGAMbAAAAgMQY2AAAAAAkxsAGAAAAIDEGNgAAAACJMbABAAAASIyBDQAAAEBiTmj0BqiO6667ruz6NddcU5Of99BDD5Vdv//++2vy82gN27ZtK1l77733yj73xRdfLFm7/fbbq74nqLYtW7aUrM2dO7f+Gymjt7e30VugRZTrPCKN1nVOtXg/Jwc6T5tP2AAAAAAkxsAGAAAAIDEGNgAAAACJMbABAAAASIyBDQAAAEBi3CUqYYODg43ewlF9+9vfLrvuLlH5+e53v1uydv311x/3959++ull10866aQp7wmm6le/+lXZ9c997nN13kltlPv3ijsw5Klc663ceYTWc+P9nBzovPX5hA0AAABAYgxsAAAAABJjYAMAAACQGAMbAAAAgMS46HAiUr7AcCVcGCo/Dz74YMnaY489Vva5+/fvr/V24Li9+OKLJWttbW0N2EljuUBrayvXeYTWP6Lz1uD9/Ajv561N50fk1rlP2AAAAAAkxsAGAAAAIDEGNgAAAACJMbABAAAASIyBDQAAAEBi3CWqznbu3NnoLUDNuRsUzSDHOy
tU4vnnny9ZO++88xqwE6ZD58dWrvMIrTcbnR+b9/PWoPNja9XOfcIGAAAAIDEGNgAAAACJMbABAAAASIyBDQAAAEBiXHS4ztrbazMj6+3trej527dvL1k74QQ5APko9745ODhYt59VS9U4x6xZs6qwExrtaO1p/Qidtwbv58em89ag82Nr1c59wgYAAAAgMQY2AAAAAIkxsAEAAABIjIENAAAAQGIMbAAAAAAS47ZACdu1a1fZ9euuu27ar/3vf/+7ZO3UU0+d9usCNLN63xUBGkXrtDqNkwOdtz6fsAEAAABIjIENAAAAQGIMbAAAAAASU/HA5rnnnotLLrkkenp6oq2tLTZt2jTp60VRxK233hqnn356zJ49O/r7++ONN96o1n6hLnRODnRODnRODnRODnROjioe2Ozfvz+WLl0aGzZsKPv1O++8M+6999544IEHYseOHTFnzpxYtWpVHDhwYNqbbQUffPBB2Udvb2/J47rrriv7qIZTTz215MHHdE4OdN4a5syZU/LgYzpvDeU61/rHdN4aNH5sOm8NOq9MxXeJuvjii+Piiy8u+7WiKOKee+6JH//4x3HppZdGRMRvfvOb6Orqik2bNsWVV145vd1CneicHOicHOicHOicHOicHFX1GjZvvvlmDA0NRX9//8RaZ2dnrFy5MrZt21b2ew4ePBj79u2b9ICU6Zwc6Jwc6Jwc6Jwc6JxWVdWBzdDQUEREdHV1TVrv6uqa+Np/W79+fXR2dk48Fi1aVM0tQdXpnBzonBzonBzonBzonFbV8LtErVu3LkZGRiYee/bsafSWoOp0Tg50Tg50Tg50Tg50TjOo6sCmu7s7IiKGh4cnrQ8PD0987b91dHTE/PnzJz0gZTonBzonBzonBzonBzqnVVV80eFjWbJkSXR3d8fmzZvj7LPPjoiIffv2xY4dO+KGG26o5o9qWueff36jtxARRy7M9d/a2tqm/bq9vb3Tfo3U6Tw9J5xQ/q1s9uzZJWujo6O13k5L0Hl6vvGNb5RdX7t2bU1+nvdznTdKudZ1PnU6T4/38+rTeXp0Xh0VD2zef//9+Nvf/jbx5zfffDNefvnlWLhwYSxevDhuuummuOOOO+LTn/50LFmyJG655Zbo6emJ1atXV3PfUFM6Jwc6Jwc6Jwc6Jwc6J0cVD2wGBwfjy1/+8sSfP5qQXXXVVfHII4/E97///di/f39ce+21sXfv3jjvvPPi6aefjlmzZlVv11BjOicHOicHOicHOicHOidHFQ9sLrjggrJ/neYjbW1tcfvtt8ftt98+rY1BI+mcHOicHOicHOicHOicHDX8LlEAAAAATFbViw7TPJYvX16ytmPHjrLPnTFjRsnaOeecU/U9wf8yODhY0fNXrFhRo51AdVTadDnlOj/aReTLvc8/9dRT094DHEutOo8o3/rR/vuM1qkl7+fkQOf15xM2AAAAAIkxsAEAAABIjIENAAAAQGIMbAAAAAASY2ADAAAAkBh3iaqCSq6W3dvbW8OdTM/KlSsbvQWYUI2r0I+Pj1dhJ1CZarRbiel2/vnPf75KOyEnzdZ5hNapXLN1rnGmQudp8wkbAAAAgMQY2AAAAAAkxsAGAAAAIDEGNgAAAACJcdHhCk33okxH+/6NGzeWrP3sZz+b1s8CoHZSuUhfR0dHydoLL7xw3K/b09Mz5T2RhxRaL9d5hNapjhQaj/B+Tm3pvDn5hA0AAABAYgxsAAAAABJjYAMAAACQGAMbAAAAgMQY2AAAAAAkxl2ijqLeV9G+8sorj2stIqIoipK15cuXV31PUA/t7dOfG/f29lZhJ3B09f53Qjnvv/9+2fVK7qwAx5JC5xHlW9c51ZJC597PqTWdtw6fsAEAAABIjIENAAAAQGIMbAAAAAASY2ADAAAAkBgDGwAAAIDEuEtURGzZsqXRW6jIHXfc0egtHPXOPjt37ixZe+CBB8o+98EHH6zqnkhfW1tbydr4+HjZ57rzE43Q0dFRk9
cdGhoqu/61r33tuF/jaHtztwWmop6tV9J5RPm96ZxKeT8nBzpvfT5hAwAAAJAYAxsAAACAxBjYAAAAACTGwAYAAAAgMS46HBFz585t9BbKev3118uuP/nkk3XeSalyFxc+muuvv77s+tVXX12ydt555011SyTkaBfyLnfxsb6+vhrvBo7fwYMHy67/6U9/Kln70Y9+VPa5H3zwQVX39JGj7a1WPvnJT5as/fOf/6zrHqidcj2V6zyifOu16jyivq2X6zxC663A+/nHvJ+3Lp1/rFU79wkbAAAAgMQY2AAAAAAkxsAGAAAAIDEGNgAAAACJMbABAAAASIy7RCVi//79JWvf/OY3G7ATqEy5u6zNmjWr7HO/8IUv1Ho7UBM333xzo7dQd+XuSNjb29uAnVAvOv+Y1luXzo/QeGvT+RGt0LlP2AAAAAAkxsAGAAAAIDEGNgAAAACJMbABAAAASIyLDkfEBRdcULK2ZcuWuu7hS1/6Ul1/XiUGBwdr8rpHuzAtzaUoipI1FxcGAACYHp+wAQAAAEiMgQ0AAABAYgxsAAAAABJjYAMAAACQmIoGNuvXr4/ly5fHvHnz4rTTTovVq1fH7t27Jz3nwIEDMTAwECeffHLMnTs3Lr/88hgeHq7qpqGWdE4OdE4OdE4OdE4OdE6uKrpL1NatW2NgYCCWL18e//nPf+KHP/xhXHTRRfH666/HnDlzIiLi5ptvjt///vfx+OOPR2dnZ6xZsyYuu+yyeOGFF2pygGp4//336/azvvrVr9btZ1WqVneDajat2vl0VdLH6tWry67/4x//qNJuqm/79u0la618tyudV6aS/g8cOFCydvDgwbLPfeWVV0rWvvjFLx7/xmro0KFDjd7CtOm8MtPtPKJ86+U6j0ijdZ3r/Fi8n6dD55XReeuoaGDz9NNPT/rzI488Eqeddlrs2rUrzj///BgZGYlf//rX8eijj8ZXvvKViIh4+OGH4zOf+Uxs3769pf/HD61D5+RA5+RA5+RA5+RA5+RqWtewGRkZiYiIhQsXRkTErl274vDhw9Hf3z/xnDPPPDMWL14c27ZtK/saBw8ejH379k16QEp0Tg50Tg50Tg50Tg50Ti6mPLAZHx+Pm266Kc4999w466yzIiJiaGgoZs6cGQsWLJj03K6urhgaGir7OuvXr4/Ozs6Jx6JFi6a6Jag6nZMDnZMDnZMDnZMDnZOTKQ9sBgYG4rXXXouNGzdOawPr1q2LkZGRiceePXum9XpQTTonBzonBzonBzonBzonJxVdw+Yja9asiaeeeiqee+65OOOMMybWu7u749ChQ7F3795J083h4eHo7u4u+1odHR3R0dExlW3UVG9vb9n19vbSGdfOnTvLPveKK64oWXv33Xent7Eq+elPf9roLRz1d5yKHDqvlU2bNpVdT/mfea5/t1nn1Tdr1qzjWotI50J95ZxzzjmN3kLV6Lz6jtZ0uXWd14fOq8/7eXp0Xn06T1tFn7ApiiLWrFkTTzzxRDzzzDOxZMmSSV9ftmxZnHjiibF58+aJtd27d8dbb70VfX191dkx1JjOyYHOyYHOyYHOyYHOyVVFn7AZGBiIRx99NJ588smYN2/exN8H7OzsjNmzZ0dnZ2d85zvfibVr18bChQtj/vz5ceONN0ZfX1+2/+81zUfn5EDn5EDn5EDn5EDn5Kqigc0vf/nLiIi44IILJq0//PDDcfXVV0dExN133x3t7e1x+eWXx8GDB2PVqlVx//33V2WzUA86Jwc6Jwc6Jwc6Jwc6J1cVDWyKovifz5k1a1Zs2LAhNmzYMOVNQSPpnBzonBzonBzonBzonFxN+S5RAAAAANTGlO4SlbPx8fGStZTvfFPurlYRERdeeGGddwIRg4ODJWvl/jMVEbFixYpabwcAACBZPmEDAAAAkBgDGwAAAIDEGNgAAAAAJMbABgAAACAxLjrc4nbu3NnoLSR9UWaOXyX/HP/whz+UXT/55JNL1u6+++4p7w
nqZWxsrGRtxowZDdjJ9PT19ZWsHT58uAE7IUXlOo9ovtbLdR6hdY7wfk4OdN46fMIGAAAAIDEGNgAAAACJMbABAAAASIyBDQAAAEBiDGwAAAAAEuMuUS3uaHf2KXeV8B07dhz3646Pj5ddX7FixXG/Bq1r1apVjd4CVNXKlStr8rqDg4PTfg134qNadE4OdE4OdN46fMIGAAAAIDEGNgAAAACJMbABAAAASIyBDQAAAEBiXHQ4U2NjYyVrLgAFUF/ed8mBzsmBzsmBzuvPJ2wAAAAAEmNgAwAAAJAYAxsAAACAxBjYAAAAACTGwAYAAAAgMQY2AAAAAIkxsAEAAABIjIENAAAAQGIMbAAAAAASY2ADAAAAkBgDGwAAAIDEGNgAAAAAJMbABgAAACAxBjYAAAAAiTGwAQAAAEjMCY3ewH8riiIiIsbGxhq8E1rdR4191Fw96Zx60Tk50Dk50Dk50Dk5qKTz5AY2o6OjERHx6quvNngn5GJ0dDQ6Ozvr/jMjdE796Jwc6Jwc6Jwc6JwcHE/nbUUjxpfHMD4+Hm+//XbMmzcvRkdHY9GiRbFnz56YP39+o7dWVfv27WvZs0U0x/mKoojR0dHo6emJ9vb6/u1AnbeGZjifzmuvGTqYjmY4n85rrxk6mI5mOJ/Oa68ZOpiOZjifzmuvGTqYjmY4XyWdJ/cJm/b29jjjjDMiIqKtrS0iIubPn5/sL3u6WvlsEemfr96T+4/ovLWkfj6d10crny0i/fPpvD5a+WwR6Z9P5/XRymeLSP98Oq+PVj5bRPrnO97OXXQYAAAAIDEGNgAAAACJSXpg09HREbfddlt0dHQ0eitV18pni2j981VTK/+uWvlsEa1/vmpq5d9VK58tovXPV02t/Ltq5bNFtP75qqmVf1etfLaI1j9fNbXy76qVzxbReudL7qLDAAAAALlL+hM2AAAAADkysAEAAABIjIENAAAAQGIMbAAAAAASk/TAZsOGDfGpT30qZs2aFStXroydO3c2eksVe+655+KSSy6Jnp6eaGtri02bNk36elEUceutt8bpp58es2fPjv7+/njjjTcas9kKrV+/PpYvXx7z5s2L0047LVavXh27d++e9JwDBw7EwMBAnHzyyTF37ty4/PLLY3h4uEE7TpPO06bz6tB52nReHTpPm86rQ+dp03l16DxtOXWe7MDmt7/9baxduzZuu+22+POf/xxLly6NVatWxbvvvtvorVVk//79sXTp0tiwYUPZr995551x7733xgMPPBA7duyIOXPmxKpVq+LAgQN13mnltm7dGgMDA7F9+/b44x//GIcPH46LLroo9u/fP/Gcm2++OX73u9/F448/Hlu3bo233347LrvssgbuOi0613kOdK7zHOhc5znQuc5zoHOdJ6VI1IoVK4qBgYGJP4+NjRU9PT3F+vXrG7ir6YmI4oknnpj48/j4eNHd3V3cddddE2t79+4tOjo6iscee6wBO5yed999t4iIYuvWrUVRHDnLiSeeWDz++OMTz/nrX/9aRESxbdu2Rm0zKTrXeQ50rvMc6FznOdC5znOgc52nJMlP2Bw6dCh27doV/f39E2vt7e3R398f27Zta+DOquvNN9+MoaGhSefs7OyMlStXNuU5R0ZGIiJi4cKFERGxa9euOHz48KTznXnmmbF48eKmPF+16VznOdC5znOgc53nQOc6z4HOdZ6aJAc27733XoyNjUVXV9ek9a6urhgaGmrQrqrvo7O0wjnHx8fjpptuinPPPTfOOuusiDhyvpkzZ8aCBQsmPbcZz1cLOm++c+q8cjpvvnPqvHI6b75z6rxyOm++c+q8cjpvvnO2eucnNHoDtIaBgYF47bXX4vnnn2/0VqBmdE4OdE4OdE4OdE4OWr3zJD9hc8opp8SMGTNKruI8PDwc3d3dDdpV9X10lmY/55o1a+Kpp56KZ599Ns4444yJ9e7u7jh06FDs3bt30vOb7Xy1ovPmOqfOp0bnzXVOnU
+NzpvrnDqfGp031zl1PjU6b65z5tB5kgObmTNnxrJly2Lz5s0Ta+Pj47F58+bo6+tr4M6qa8mSJdHd3T3pnPv27YsdO3Y0xTmLoog1a9bEE088Ec8880wsWbJk0teXLVsWJ5544qTz7d69O956662mOF+t6VznOdC5znOgc53nQOc6z4HOdZ6cRl7x+Fg2btxYdHR0FI888kjx+uuvF9dee22xYMGCYmhoqNFbq8jo6Gjx0ksvFS+99FIREcXPf/7z4qWXXir+/ve/F0VRFD/5yU+KBQsWFE8++WTxyiuvFJdeemmxZMmS4sMPP2zwzv+3G264oejs7Cy2bNlSvPPOOxOPDz74YOI5119/fbF48eLimWeeKQYHB4u+vr6ir6+vgbtOi851ngOd6zwHOtd5DnSu8xzoXOcpSXZgUxRFcd999xWLFy8uZs6cWaxYsaLYvn17o7dUsWeffbaIiJLHVVddVRTFkVuq3XLLLUVXV1fR0dFRXHjhhcXu3bsbu+njVO5cEVE8/PDDE8/58MMPi+9973vFSSedVHziE58ovv71rxfvvPNO4zadIJ2nTefVofO06bw6dJ42nVeHztOm8+rQedpy6rytKIpi6p/PAQAAAKDakryGDQAAAEDODGwAAAAAEmNgAwAAAJAYAxsAAACAxBjYAAAAACTGwAYAAAAgMQY2AAAAAIkxsAEAAABIjIENAAAAQGIMbAAAAAASY2ADAAAAkBgDGwAAAIDE/B8p7caKHn6ulwAAAABJRU5ErkJggg==",
            "text/plain": [
              "<Figure size 1400x1400 with 6 Axes>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "import numpy as np\n",
        "\n",
        "from matplotlib import pyplot as plt\n",
        "\n",
        "%matplotlib inline\n",
        "\n",
        "\n",
        "fig, ax = plt.subplots(1, 6, figsize=(14, 14))\n",
        "percentages = np.array([0, 10, 25, 50, 75, 100], dtype=np.float32)\n",
        "inds = (N_BATCH * BATCH_SIZE * percentages / 100 + 4).astype(int)\n",
        "\n",
        "for i, ax in enumerate(ax.flat):\n",
        "    b = torch.argmax(score_image(decode(train_x[: inds[i], :])), dim=0)\n",
        "    img = decode(train_x[b].view(1, -1)).squeeze().cpu()\n",
        "    ax.imshow(img, alpha=0.8, cmap=\"gray\")"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "python3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
